//
// Created by neo on 25-4-3.
//

#include "MedianFilter.h"

#include <core/image/filters/BasicFilter.h>
#include <memory>
#include <runtime/config.h>
#include <runtime/gpu/VkGPUContext.h>
#include <vulkan/vulkan_core.h>

#ifdef OS_OPEN_HARMONY
#include <runtime/gpu/utils/vk_enum_string_helper.h>
#else
#include <vulkan/vk_enum_string_helper.h>
#endif

#include "runtime/gpu/VkGPUHelper.h"
#include "runtime/gpu/compute_graph/ComputePipelineNode.h"
#include "runtime/log/Log.h"

std::shared_ptr<SubComputeGraph> MedianFilter::CreateParallelSubGraph(
    const size_t parallelIndex, const std::shared_ptr<VkGPUContext> &gpuCtx,
    const VkBuffer inputBuffer, const VkDeviceSize inputBufferSize,
    const VkBuffer outputBuffer, const VkDeviceSize outputBufferSize) {
  const auto computeSubGraph = std::make_shared<SubComputeGraph>(gpuCtx);
  VkResult ret = computeSubGraph->Init();
  if (ret != VK_SUCCESS) {
    Logger() << "Failed to create compute graph, err =" << string_VkResult(ret)
             << std::endl;
    return nullptr;
  }

  PushConstantInfo pushConstantInfo;
  pushConstantInfo.size = sizeof(MedianFilterParams);
  pushConstantInfo.data = &this->medianFilterParams[parallelIndex];

  PipelineNodeBuffer pipelineNodeInput;
  pipelineNodeInput.type = PIPELINE_NODE_BUFFER_STORAGE_READ;
  pipelineNodeInput.buf.buffer = inputBuffer;
  pipelineNodeInput.buf.bufferSize = inputBufferSize;

  PipelineNodeBuffer pipelineNodeOutput;
  pipelineNodeOutput.type = PIPELINE_NODE_BUFFER_STORAGE_WRITE;
  pipelineNodeOutput.buf.buffer = outputBuffer;
  pipelineNodeOutput.buf.bufferSize = outputBufferSize;

  std::vector<PipelineNodeBuffer> pipelineBuffers;
  pipelineBuffers.push_back(pipelineNodeInput);
  pipelineBuffers.push_back(pipelineNodeOutput);

  std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
  descriptorSetLayoutBindings.push_back(
      VkGPUHelper::BuildDescriptorSetLayoutBinding(
          0, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1,
          VK_SHADER_STAGE_COMPUTE_BIT));
  descriptorSetLayoutBindings.push_back(
      VkGPUHelper::BuildDescriptorSetLayoutBinding(
          1, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1,
          VK_SHADER_STAGE_COMPUTE_BIT));

  const auto node = std::make_shared<ComputePipelineNode>(
      gpuCtx, "MedianFilter", SHADER(midvalue.comp.glsl.spv),
      pushConstantInfo.size, descriptorSetLayoutBindings,
      (this->medianFilterParams[parallelIndex].imageSize.width + 31) / 32,
      (this->medianFilterParams[parallelIndex].imageSize.height + 31) / 32, 1);
  ret = node->CreateComputeGraphNode();
  if (ret != VK_SUCCESS) {
    Logger() << "Failed to create compute graph, err =" << string_VkResult(ret)
             << std::endl;
    return nullptr;
  }
  node->AddComputeElement(
      {.pushConstantInfo = pushConstantInfo, .buffers = pipelineBuffers});
  computeSubGraph->AddComputeGraphNode(node);
  return computeSubGraph;
}

/**
 * Runs the median filter on inputImageInfo[0] and writes the result to
 * outputImageInfo[0].
 *
 * The work is split into up to 4 pieces, one sub graph per piece, bounded by
 * the number of parallel compute/transfer/graphics queues reported by gpuCtx.
 * Each call rebuilds this->computeGraph from scratch. Any graph left over from
 * a previous call is Destroy()ed first.
 *
 * @param gpuCtx          GPU context providing queues and resources.
 * @param inputImageInfo  source images; only element 0 is used and it must
 *                        exist.
 * @param outputImageInfo destination images; only element 0 is used and it
 *                        must exist.
 * @return VK_ERROR_INITIALIZATION_FAILED if either image list is empty or a
 *         sub graph could not be built; otherwise the result of
 *         computeGraph->Compute().
 */
VkResult
MedianFilter::Apply(const std::shared_ptr<VkGPUContext> &gpuCtx,
                    const std::vector<FilterImageInfo> &inputImageInfo,
                    const std::vector<FilterImageInfo> &outputImageInfo) {
  // Element 0 of both lists is dereferenced below; indexing an empty vector
  // would be undefined behaviour.
  if (inputImageInfo.empty() || outputImageInfo.empty()) {
    Logger() << "MedianFilter: input and output image lists must not be empty"
             << std::endl;
    return VK_ERROR_INITIALIZATION_FAILED;
  }

  // Release the graph from a previous Apply() explicitly instead of silently
  // dropping it. Its push-constant pointers refer into medianFilterParams,
  // which is resized below, so that graph cannot be reused anyway.
  if (this->computeGraph != nullptr) {
    this->computeGraph->Destroy();
    this->computeGraph = nullptr;
  }

  constexpr uint32_t kMaxParallelSize = 4;
  uint32_t parallelSize = 1;
  const std::vector<DeviceQueue> parallelQueues = gpuCtx->GetAllParallelQueue(
      VK_QUEUE_COMPUTE_BIT | VK_QUEUE_TRANSFER_BIT | VK_QUEUE_GRAPHICS_BIT);
  if (parallelQueues.empty()) {
    Logger() << "No parallel queues found!\n";
  } else {
    // Clamp before narrowing size_t -> uint32_t.
    parallelSize = parallelQueues.size() > kMaxParallelSize
                       ? kMaxParallelSize
                       : static_cast<uint32_t>(parallelQueues.size());
    Logger() << "Parallel size:" << parallelSize
             << ", all queue: " << parallelQueues.size() << std::endl;
  }

  // Sized once here; CreateParallelSubGraph keeps pointers into this vector,
  // so it must not be resized again while the graph is alive.
  this->medianFilterParams.resize(parallelSize);
  this->computeGraph = std::make_shared<ComputeGraph>(gpuCtx);

  const FilterImageInfo &input = inputImageInfo[0];
  const FilterImageInfo &output = outputImageInfo[0];
  for (size_t parallelIndex = 0; parallelIndex < parallelSize;
       parallelIndex++) {
    auto &params = this->medianFilterParams[parallelIndex];
    params.imageSize.width = input.width;
    params.imageSize.height = input.height;
    params.imageSize.channels = 4;
    params.imageSize.bytesPerLine = params.imageSize.width * 4;
    params.pieceCount = parallelSize;
    params.piece = parallelIndex;
    params.radius = radius;

    const std::shared_ptr<SubComputeGraph> parallelGraph =
        CreateParallelSubGraph(parallelIndex, gpuCtx, input.storageBuffer,
                               input.bufferSize, output.storageBuffer,
                               output.bufferSize);
    if (parallelGraph == nullptr) {
      // The partially built graph stays in this->computeGraph and is
      // released by Destroy() or by the next Apply().
      return VK_ERROR_INITIALIZATION_FAILED;
    }
    this->computeGraph->AddSubGraph(parallelGraph);
  }

  return this->computeGraph->Compute();
}

/**
 * Releases the compute graph built by Apply(), if there is one.
 * The graph is Destroy()ed and the member is then cleared, so calling this
 * more than once is harmless.
 */
void MedianFilter::Destroy() {
  if (!this->computeGraph) {
    return;  // nothing was built, or it has already been released
  }
  this->computeGraph->Destroy();
  this->computeGraph.reset();
}
